{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.1"
    },
    "toc": {
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "mlp-basic.ipynb",
      "provenance": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gE7RNWDjLl1o",
        "colab_type": "text"
      },
      "source": [
        "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
        "- Author: Sebastian Raschka\n",
        "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pAaXIbUhLnVT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install the third-party packages used by this notebook.\n",
        "!pip install -q IPython\n",
        "!pip install -q ipykernel\n",
        "!pip install -q watermark\n",
        "!pip install -q matplotlib\n",
        "!pip install -q tensorwatch\n",
        "# The PyPI distribution is named `scikit-learn`; the `sklearn` alias on\n",
        "# PyPI is deprecated.\n",
        "!pip install -q scikit-learn\n",
        "!pip install -q pandas\n",
        "!pip install -q pydot\n",
        "!pip install -q hiddenlayer\n",
        "!pip install -q graphviz"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CeX28YdYLl1q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        },
        "outputId": "015b0f0d-cd7b-41a6-bcae-e0a9c243222d"
      },
      "source": [
        "# Load the watermark extension, then print the author together with the\n",
        "# Python/IPython versions and the installed torch version.\n",
        "%load_ext watermark\n",
        "%watermark --author 'Sebastian Raschka' --python --packages torch"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sebastian Raschka \n",
            "\n",
            "CPython 3.6.9\n",
            "IPython 5.5.0\n",
            "\n",
            "torch 1.5.1+cu101\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61ZoqACKLl11",
        "colab_type": "text"
      },
      "source": [
        "- Runs on CPU or GPU (if available)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-2d6qf1HLl12",
        "colab_type": "text"
      },
      "source": [
        "# Model Zoo -- Multilayer Perceptron"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AP2yc5b_Ll13",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uFPxT2CSLl14",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "import numpy as np\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms\n",
        "from torch.utils.data import DataLoader\n",
        "import torch.nn.functional as F\n",
        "import torch\n",
        "\n",
        "\n",
        "# When a CUDA device is present, switch on cuDNN's deterministic flag\n",
        "# (intended to make repeated GPU runs more reproducible).\n",
        "if torch.cuda.is_available():\n",
        "    torch.backends.cudnn.deterministic = True"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O5mxzXlwLl18",
        "colab_type": "text"
      },
      "source": [
        "## Settings and Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ry66I9aJLl19",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "efcd0ccc-1913-4563-b806-138ecc8a92bf"
      },
      "source": [
        "##########################\n",
        "### SETTINGS\n",
        "##########################\n",
        "\n",
        "# Device\n",
        "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Hyperparameters\n",
        "random_seed = 1\n",
        "learning_rate = 0.1\n",
        "num_epochs = 10\n",
        "batch_size = 64\n",
        "\n",
        "# Architecture\n",
        "num_features = 784\n",
        "num_hidden_1 = 128\n",
        "num_hidden_2 = 256\n",
        "num_classes = 10\n",
        "\n",
        "\n",
        "##########################\n",
        "### MNIST DATASET\n",
        "##########################\n",
        "\n",
        "# Note transforms.ToTensor() scales input images\n",
        "# to 0-1 range\n",
        "# The training split is requested with download=True and root='data',\n",
        "# so the files are stored under the local 'data' directory.\n",
        "train_dataset = datasets.MNIST(root='data', \n",
        "                               train=True, \n",
        "                               transform=transforms.ToTensor(),\n",
        "                               download=True)\n",
        "\n",
        "# No download flag is passed for the test split; it is read from the\n",
        "# same 'data' root that the training-split call above populated.\n",
        "test_dataset = datasets.MNIST(root='data', \n",
        "                              train=False, \n",
        "                              transform=transforms.ToTensor())\n",
        "\n",
        "\n",
        "# Training batches are shuffled; test batches keep a fixed order.\n",
        "train_loader = DataLoader(dataset=train_dataset, \n",
        "                          batch_size=batch_size, \n",
        "                          shuffle=True)\n",
        "\n",
        "test_loader = DataLoader(dataset=test_dataset, \n",
        "                         batch_size=batch_size, \n",
        "                         shuffle=False)\n",
        "\n",
        "# Sanity check: pull a single batch and report the tensor shapes\n",
        "images, labels = next(iter(train_loader))\n",
        "print('Image batch dimensions:', images.shape)\n",
        "print('Image label dimensions:', labels.shape)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Image batch dimensions: torch.Size([64, 1, 28, 28])\n",
            "Image label dimensions: torch.Size([64])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cKwnbpTkLl2B",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### MODEL\n",
        "##########################\n",
        "\n",
        "class MultilayerPerceptron(torch.nn.Module):\n",
        "    \"\"\"Fully connected classifier with two ReLU hidden layers.\n",
        "\n",
        "    Args:\n",
        "        num_features: Size of each flattened input vector.\n",
        "        num_classes: Number of output units (one logit per class).\n",
        "        num_hidden_1: Number of units in the first hidden layer.\n",
        "        num_hidden_2: Number of units in the second hidden layer.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_features, num_classes,\n",
        "                 num_hidden_1=128, num_hidden_2=256):\n",
        "        super(MultilayerPerceptron, self).__init__()\n",
        "\n",
        "        ### 1st hidden layer\n",
        "        self.linear_1 = torch.nn.Linear(num_features, num_hidden_1)\n",
        "        # The explicit re-initialization below is not necessary; it is\n",
        "        # used here to demonstrate how to access the parameters and apply\n",
        "        # a different initialization. By default, torch.nn.Linear draws\n",
        "        # weights and biases from U(-sqrt(k), sqrt(k)) with\n",
        "        # k = 1/in_features (not Xavier/Glorot), which should usually be\n",
        "        # preferred.\n",
        "        torch.nn.init.normal_(self.linear_1.weight, mean=0.0, std=0.1)\n",
        "        torch.nn.init.zeros_(self.linear_1.bias)\n",
        "\n",
        "        ### 2nd hidden layer\n",
        "        self.linear_2 = torch.nn.Linear(num_hidden_1, num_hidden_2)\n",
        "        torch.nn.init.normal_(self.linear_2.weight, mean=0.0, std=0.1)\n",
        "        torch.nn.init.zeros_(self.linear_2.bias)\n",
        "\n",
        "        ### Output layer\n",
        "        self.linear_out = torch.nn.Linear(num_hidden_2, num_classes)\n",
        "        torch.nn.init.normal_(self.linear_out.weight, mean=0.0, std=0.1)\n",
        "        torch.nn.init.zeros_(self.linear_out.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Compute class scores for a batch of flattened inputs.\n",
        "\n",
        "        Args:\n",
        "            x: 2-D tensor of shape (batch, num_features); the log-softmax\n",
        "                below is taken over dimension 1.\n",
        "\n",
        "        Returns:\n",
        "            Tuple ``(logits, log_probas)``: the raw outputs of the last\n",
        "            linear layer and their log-softmax. The second element holds\n",
        "            log-probabilities, not probabilities.\n",
        "        \"\"\"\n",
        "        out = F.relu(self.linear_1(x))\n",
        "        out = F.relu(self.linear_2(out))\n",
        "        logits = self.linear_out(out)\n",
        "        log_probas = F.log_softmax(logits, dim=1)\n",
        "        return logits, log_probas\n",
        "\n",
        "\n",
        "# Seed right before construction so the random weight initialization is\n",
        "# repeatable.\n",
        "torch.manual_seed(random_seed)\n",
        "model = MultilayerPerceptron(num_features=num_features,\n",
        "                             num_classes=num_classes,\n",
        "                             num_hidden_1=num_hidden_1,\n",
        "                             num_hidden_2=num_hidden_2)\n",
        "\n",
        "model = model.to(device)\n",
        "\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZVbpGQllUS8K",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 176
        },
        "outputId": "aa1c2768-982e-4575-ea00-732452288fb9"
      },
      "source": [
        "import hiddenlayer as hl\n",
        "\n",
        "# Trace the network with an all-zero dummy batch and render its layer graph.\n",
        "# The batch is sized from the notebook settings (batch_size x num_features)\n",
        "# rather than hard-coded numbers, so the graph follows the configuration.\n",
        "hl.build_graph(model, torch.zeros([batch_size, num_features]).to(device))"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<hiddenlayer.graph.Graph at 0x7f52589f2e80>"
            ],
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"468pt\" height=\"116pt\"\n viewBox=\"0.00 0.00 468.00 116.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 80)\">\n<title>%3</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-72,36 -72,-80 396,-80 396,36 -72,36\"/>\n<!-- /outputs/11 -->\n<g id=\"node1\" class=\"node\">\n<title>/outputs/11</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"194,-40 140,-40 140,-4 194,-4 194,-40\"/>\n<text text-anchor=\"start\" x=\"154\" y=\"-19\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n</g>\n<!-- /outputs/12 -->\n<g id=\"node2\" class=\"node\">\n<title>/outputs/12</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"324,-40 257,-40 257,-4 324,-4 324,-40\"/>\n<text text-anchor=\"start\" x=\"265.5\" y=\"-19\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">LogSoftmax</text>\n</g>\n<!-- /outputs/11&#45;&gt;/outputs/12 -->\n<g id=\"edge1\" class=\"edge\">\n<title>/outputs/11&#45;&gt;/outputs/12</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M194.1026,-22C209.535,-22 229.2361,-22 246.8008,-22\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"246.8475,-25.5001 256.8474,-22 246.8474,-18.5001 246.8475,-25.5001\"/>\n<text text-anchor=\"middle\" x=\"225.5\" y=\"-25\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">64x10</text>\n</g>\n<!-- 11352692832848273995 -->\n<g id=\"node3\" class=\"node\">\n<title>11352692832848273995</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"72,-44 0,-44 0,0 72,0 72,-44\"/>\n<text text-anchor=\"start\" x=\"8\" y=\"-28\" 
font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear &gt; Relu</text>\n<text text-anchor=\"start\" x=\"57\" y=\"-7\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">x2</text>\n</g>\n<!-- 11352692832848273995&#45;&gt;/outputs/11 -->\n<g id=\"edge2\" class=\"edge\">\n<title>11352692832848273995&#45;&gt;/outputs/11</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M72.1431,-22C89.9959,-22 111.5175,-22 129.5449,-22\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"129.7013,-25.5001 139.7013,-22 129.7013,-18.5001 129.7013,-25.5001\"/>\n<text text-anchor=\"middle\" x=\"106\" y=\"-25\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">64x256</text>\n</g>\n</g>\n</svg>\n"
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hkB2CxijLl2G",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "21cd93f1-888b-4fba-ef78-d823e4c2bfb6"
      },
      "source": [
        "def compute_accuracy(net, data_loader):\n",
        "    net.eval()\n",
        "    correct_pred, num_examples = 0, 0\n",
        "    with torch.no_grad():\n",
        "        for features, targets in data_loader:\n",
        "            features = features.view(-1, 28*28).to(device)\n",
        "            targets = targets.to(device)\n",
        "            logits, probas = net(features)\n",
        "            _, predicted_labels = torch.max(probas, 1)\n",
        "            num_examples += targets.size(0)\n",
        "            correct_pred += (predicted_labels == targets).sum()\n",
        "        return correct_pred.float()/num_examples * 100\n",
        "    \n",
        "\n",
        "# Reference point for the wall-clock timings printed below\n",
        "start_time = time.time()\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    # compute_accuracy() puts the network into eval mode, so training mode\n",
        "    # is switched back on at the start of every epoch.\n",
        "    model.train()\n",
        "\n",
        "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
        "\n",
        "        # Flatten every image into a 784-dimensional row vector\n",
        "        features = features.view(-1, 28*28).to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        ### FORWARD AND BACK PROP\n",
        "        logits, probas = model(features)\n",
        "        cost = F.cross_entropy(logits, targets)\n",
        "        optimizer.zero_grad()\n",
        "        cost.backward()\n",
        "\n",
        "        ### UPDATE MODEL PARAMETERS\n",
        "        optimizer.step()\n",
        "\n",
        "        ### LOGGING (every 50th mini-batch)\n",
        "        if batch_idx % 50 == 0:\n",
        "            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' % (\n",
        "                epoch+1, num_epochs, batch_idx, len(train_loader), cost))\n",
        "\n",
        "    # Accuracy over the complete training set, evaluated without gradients\n",
        "    with torch.no_grad():\n",
        "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
        "            epoch+1, num_epochs, compute_accuracy(model, train_loader)))\n",
        "\n",
        "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
        "\n",
        "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001/010 | Batch 000/938 | Cost: 2.4231\n",
            "Epoch: 001/010 | Batch 050/938 | Cost: 0.7213\n",
            "Epoch: 001/010 | Batch 100/938 | Cost: 0.3137\n",
            "Epoch: 001/010 | Batch 150/938 | Cost: 0.4774\n",
            "Epoch: 001/010 | Batch 200/938 | Cost: 0.3311\n",
            "Epoch: 001/010 | Batch 250/938 | Cost: 0.3081\n",
            "Epoch: 001/010 | Batch 300/938 | Cost: 0.3944\n",
            "Epoch: 001/010 | Batch 350/938 | Cost: 0.2296\n",
            "Epoch: 001/010 | Batch 400/938 | Cost: 0.0828\n",
            "Epoch: 001/010 | Batch 450/938 | Cost: 0.2247\n",
            "Epoch: 001/010 | Batch 500/938 | Cost: 0.2467\n",
            "Epoch: 001/010 | Batch 550/938 | Cost: 0.1541\n",
            "Epoch: 001/010 | Batch 600/938 | Cost: 0.2409\n",
            "Epoch: 001/010 | Batch 650/938 | Cost: 0.4143\n",
            "Epoch: 001/010 | Batch 700/938 | Cost: 0.2905\n",
            "Epoch: 001/010 | Batch 750/938 | Cost: 0.3491\n",
            "Epoch: 001/010 | Batch 800/938 | Cost: 0.1403\n",
            "Epoch: 001/010 | Batch 850/938 | Cost: 0.0682\n",
            "Epoch: 001/010 | Batch 900/938 | Cost: 0.2784\n",
            "Epoch: 001/010 training accuracy: 94.88%\n",
            "Time elapsed: 0.23 min\n",
            "Epoch: 002/010 | Batch 000/938 | Cost: 0.3281\n",
            "Epoch: 002/010 | Batch 050/938 | Cost: 0.2818\n",
            "Epoch: 002/010 | Batch 100/938 | Cost: 0.0917\n",
            "Epoch: 002/010 | Batch 150/938 | Cost: 0.2296\n",
            "Epoch: 002/010 | Batch 200/938 | Cost: 0.2237\n",
            "Epoch: 002/010 | Batch 250/938 | Cost: 0.1611\n",
            "Epoch: 002/010 | Batch 300/938 | Cost: 0.0540\n",
            "Epoch: 002/010 | Batch 350/938 | Cost: 0.2380\n",
            "Epoch: 002/010 | Batch 400/938 | Cost: 0.0532\n",
            "Epoch: 002/010 | Batch 450/938 | Cost: 0.0970\n",
            "Epoch: 002/010 | Batch 500/938 | Cost: 0.2395\n",
            "Epoch: 002/010 | Batch 550/938 | Cost: 0.2064\n",
            "Epoch: 002/010 | Batch 600/938 | Cost: 0.1637\n",
            "Epoch: 002/010 | Batch 650/938 | Cost: 0.1724\n",
            "Epoch: 002/010 | Batch 700/938 | Cost: 0.2009\n",
            "Epoch: 002/010 | Batch 750/938 | Cost: 0.0853\n",
            "Epoch: 002/010 | Batch 800/938 | Cost: 0.1704\n",
            "Epoch: 002/010 | Batch 850/938 | Cost: 0.1077\n",
            "Epoch: 002/010 | Batch 900/938 | Cost: 0.1550\n",
            "Epoch: 002/010 training accuracy: 96.62%\n",
            "Time elapsed: 0.46 min\n",
            "Epoch: 003/010 | Batch 000/938 | Cost: 0.1078\n",
            "Epoch: 003/010 | Batch 050/938 | Cost: 0.0976\n",
            "Epoch: 003/010 | Batch 100/938 | Cost: 0.0772\n",
            "Epoch: 003/010 | Batch 150/938 | Cost: 0.1943\n",
            "Epoch: 003/010 | Batch 200/938 | Cost: 0.1211\n",
            "Epoch: 003/010 | Batch 250/938 | Cost: 0.2548\n",
            "Epoch: 003/010 | Batch 300/938 | Cost: 0.0644\n",
            "Epoch: 003/010 | Batch 350/938 | Cost: 0.1130\n",
            "Epoch: 003/010 | Batch 400/938 | Cost: 0.0820\n",
            "Epoch: 003/010 | Batch 450/938 | Cost: 0.0835\n",
            "Epoch: 003/010 | Batch 500/938 | Cost: 0.0741\n",
            "Epoch: 003/010 | Batch 550/938 | Cost: 0.0485\n",
            "Epoch: 003/010 | Batch 600/938 | Cost: 0.1994\n",
            "Epoch: 003/010 | Batch 650/938 | Cost: 0.1151\n",
            "Epoch: 003/010 | Batch 700/938 | Cost: 0.2079\n",
            "Epoch: 003/010 | Batch 750/938 | Cost: 0.0624\n",
            "Epoch: 003/010 | Batch 800/938 | Cost: 0.0201\n",
            "Epoch: 003/010 | Batch 850/938 | Cost: 0.0175\n",
            "Epoch: 003/010 | Batch 900/938 | Cost: 0.0662\n",
            "Epoch: 003/010 training accuracy: 97.70%\n",
            "Time elapsed: 0.69 min\n",
            "Epoch: 004/010 | Batch 000/938 | Cost: 0.0064\n",
            "Epoch: 004/010 | Batch 050/938 | Cost: 0.0339\n",
            "Epoch: 004/010 | Batch 100/938 | Cost: 0.1611\n",
            "Epoch: 004/010 | Batch 150/938 | Cost: 0.0588\n",
            "Epoch: 004/010 | Batch 200/938 | Cost: 0.0341\n",
            "Epoch: 004/010 | Batch 250/938 | Cost: 0.1547\n",
            "Epoch: 004/010 | Batch 300/938 | Cost: 0.2144\n",
            "Epoch: 004/010 | Batch 350/938 | Cost: 0.0218\n",
            "Epoch: 004/010 | Batch 400/938 | Cost: 0.1184\n",
            "Epoch: 004/010 | Batch 450/938 | Cost: 0.1752\n",
            "Epoch: 004/010 | Batch 500/938 | Cost: 0.1361\n",
            "Epoch: 004/010 | Batch 550/938 | Cost: 0.1090\n",
            "Epoch: 004/010 | Batch 600/938 | Cost: 0.0370\n",
            "Epoch: 004/010 | Batch 650/938 | Cost: 0.0513\n",
            "Epoch: 004/010 | Batch 700/938 | Cost: 0.1914\n",
            "Epoch: 004/010 | Batch 750/938 | Cost: 0.1602\n",
            "Epoch: 004/010 | Batch 800/938 | Cost: 0.0733\n",
            "Epoch: 004/010 | Batch 850/938 | Cost: 0.0516\n",
            "Epoch: 004/010 | Batch 900/938 | Cost: 0.1485\n",
            "Epoch: 004/010 training accuracy: 97.97%\n",
            "Time elapsed: 0.93 min\n",
            "Epoch: 005/010 | Batch 000/938 | Cost: 0.0238\n",
            "Epoch: 005/010 | Batch 050/938 | Cost: 0.0148\n",
            "Epoch: 005/010 | Batch 100/938 | Cost: 0.1155\n",
            "Epoch: 005/010 | Batch 150/938 | Cost: 0.0689\n",
            "Epoch: 005/010 | Batch 200/938 | Cost: 0.0433\n",
            "Epoch: 005/010 | Batch 250/938 | Cost: 0.0794\n",
            "Epoch: 005/010 | Batch 300/938 | Cost: 0.0672\n",
            "Epoch: 005/010 | Batch 350/938 | Cost: 0.0883\n",
            "Epoch: 005/010 | Batch 400/938 | Cost: 0.0380\n",
            "Epoch: 005/010 | Batch 450/938 | Cost: 0.0508\n",
            "Epoch: 005/010 | Batch 500/938 | Cost: 0.1399\n",
            "Epoch: 005/010 | Batch 550/938 | Cost: 0.1173\n",
            "Epoch: 005/010 | Batch 600/938 | Cost: 0.0662\n",
            "Epoch: 005/010 | Batch 650/938 | Cost: 0.0267\n",
            "Epoch: 005/010 | Batch 700/938 | Cost: 0.0606\n",
            "Epoch: 005/010 | Batch 750/938 | Cost: 0.0460\n",
            "Epoch: 005/010 | Batch 800/938 | Cost: 0.0441\n",
            "Epoch: 005/010 | Batch 850/938 | Cost: 0.1089\n",
            "Epoch: 005/010 | Batch 900/938 | Cost: 0.0349\n",
            "Epoch: 005/010 training accuracy: 98.58%\n",
            "Time elapsed: 1.16 min\n",
            "Epoch: 006/010 | Batch 000/938 | Cost: 0.1055\n",
            "Epoch: 006/010 | Batch 050/938 | Cost: 0.0898\n",
            "Epoch: 006/010 | Batch 100/938 | Cost: 0.0692\n",
            "Epoch: 006/010 | Batch 150/938 | Cost: 0.0287\n",
            "Epoch: 006/010 | Batch 200/938 | Cost: 0.0411\n",
            "Epoch: 006/010 | Batch 250/938 | Cost: 0.0761\n",
            "Epoch: 006/010 | Batch 300/938 | Cost: 0.0850\n",
            "Epoch: 006/010 | Batch 350/938 | Cost: 0.0616\n",
            "Epoch: 006/010 | Batch 400/938 | Cost: 0.1172\n",
            "Epoch: 006/010 | Batch 450/938 | Cost: 0.0509\n",
            "Epoch: 006/010 | Batch 500/938 | Cost: 0.0706\n",
            "Epoch: 006/010 | Batch 550/938 | Cost: 0.1720\n",
            "Epoch: 006/010 | Batch 600/938 | Cost: 0.0539\n",
            "Epoch: 006/010 | Batch 650/938 | Cost: 0.0357\n",
            "Epoch: 006/010 | Batch 700/938 | Cost: 0.0727\n",
            "Epoch: 006/010 | Batch 750/938 | Cost: 0.1152\n",
            "Epoch: 006/010 | Batch 800/938 | Cost: 0.0567\n",
            "Epoch: 006/010 | Batch 850/938 | Cost: 0.0786\n",
            "Epoch: 006/010 | Batch 900/938 | Cost: 0.1787\n",
            "Epoch: 006/010 training accuracy: 98.83%\n",
            "Time elapsed: 1.39 min\n",
            "Epoch: 007/010 | Batch 000/938 | Cost: 0.0048\n",
            "Epoch: 007/010 | Batch 050/938 | Cost: 0.0899\n",
            "Epoch: 007/010 | Batch 100/938 | Cost: 0.0349\n",
            "Epoch: 007/010 | Batch 150/938 | Cost: 0.1163\n",
            "Epoch: 007/010 | Batch 200/938 | Cost: 0.0259\n",
            "Epoch: 007/010 | Batch 250/938 | Cost: 0.0423\n",
            "Epoch: 007/010 | Batch 300/938 | Cost: 0.0619\n",
            "Epoch: 007/010 | Batch 350/938 | Cost: 0.0347\n",
            "Epoch: 007/010 | Batch 400/938 | Cost: 0.0254\n",
            "Epoch: 007/010 | Batch 450/938 | Cost: 0.0486\n",
            "Epoch: 007/010 | Batch 500/938 | Cost: 0.0165\n",
            "Epoch: 007/010 | Batch 550/938 | Cost: 0.0834\n",
            "Epoch: 007/010 | Batch 600/938 | Cost: 0.1113\n",
            "Epoch: 007/010 | Batch 650/938 | Cost: 0.0823\n",
            "Epoch: 007/010 | Batch 700/938 | Cost: 0.0291\n",
            "Epoch: 007/010 | Batch 750/938 | Cost: 0.0654\n",
            "Epoch: 007/010 | Batch 800/938 | Cost: 0.0155\n",
            "Epoch: 007/010 | Batch 850/938 | Cost: 0.0663\n",
            "Epoch: 007/010 | Batch 900/938 | Cost: 0.0562\n",
            "Epoch: 007/010 training accuracy: 99.00%\n",
            "Time elapsed: 1.62 min\n",
            "Epoch: 008/010 | Batch 000/938 | Cost: 0.0232\n",
            "Epoch: 008/010 | Batch 050/938 | Cost: 0.0175\n",
            "Epoch: 008/010 | Batch 100/938 | Cost: 0.0415\n",
            "Epoch: 008/010 | Batch 150/938 | Cost: 0.0376\n",
            "Epoch: 008/010 | Batch 200/938 | Cost: 0.0077\n",
            "Epoch: 008/010 | Batch 250/938 | Cost: 0.0588\n",
            "Epoch: 008/010 | Batch 300/938 | Cost: 0.0243\n",
            "Epoch: 008/010 | Batch 350/938 | Cost: 0.0342\n",
            "Epoch: 008/010 | Batch 400/938 | Cost: 0.0108\n",
            "Epoch: 008/010 | Batch 450/938 | Cost: 0.0051\n",
            "Epoch: 008/010 | Batch 500/938 | Cost: 0.0127\n",
            "Epoch: 008/010 | Batch 550/938 | Cost: 0.0043\n",
            "Epoch: 008/010 | Batch 600/938 | Cost: 0.0184\n",
            "Epoch: 008/010 | Batch 650/938 | Cost: 0.0423\n",
            "Epoch: 008/010 | Batch 700/938 | Cost: 0.0481\n",
            "Epoch: 008/010 | Batch 750/938 | Cost: 0.0171\n",
            "Epoch: 008/010 | Batch 800/938 | Cost: 0.0383\n",
            "Epoch: 008/010 | Batch 850/938 | Cost: 0.0053\n",
            "Epoch: 008/010 | Batch 900/938 | Cost: 0.0519\n",
            "Epoch: 008/010 training accuracy: 99.24%\n",
            "Time elapsed: 1.85 min\n",
            "Epoch: 009/010 | Batch 000/938 | Cost: 0.0401\n",
            "Epoch: 009/010 | Batch 050/938 | Cost: 0.0586\n",
            "Epoch: 009/010 | Batch 100/938 | Cost: 0.0111\n",
            "Epoch: 009/010 | Batch 150/938 | Cost: 0.0157\n",
            "Epoch: 009/010 | Batch 200/938 | Cost: 0.0814\n",
            "Epoch: 009/010 | Batch 250/938 | Cost: 0.0054\n",
            "Epoch: 009/010 | Batch 300/938 | Cost: 0.0067\n",
            "Epoch: 009/010 | Batch 350/938 | Cost: 0.0063\n",
            "Epoch: 009/010 | Batch 400/938 | Cost: 0.0482\n",
            "Epoch: 009/010 | Batch 450/938 | Cost: 0.0184\n",
            "Epoch: 009/010 | Batch 500/938 | Cost: 0.0251\n",
            "Epoch: 009/010 | Batch 550/938 | Cost: 0.0136\n",
            "Epoch: 009/010 | Batch 600/938 | Cost: 0.0953\n",
            "Epoch: 009/010 | Batch 650/938 | Cost: 0.0502\n",
            "Epoch: 009/010 | Batch 700/938 | Cost: 0.0194\n",
            "Epoch: 009/010 | Batch 750/938 | Cost: 0.0590\n",
            "Epoch: 009/010 | Batch 800/938 | Cost: 0.0069\n",
            "Epoch: 009/010 | Batch 850/938 | Cost: 0.1538\n",
            "Epoch: 009/010 | Batch 900/938 | Cost: 0.0026\n",
            "Epoch: 009/010 training accuracy: 99.30%\n",
            "Time elapsed: 2.08 min\n",
            "Epoch: 010/010 | Batch 000/938 | Cost: 0.0189\n",
            "Epoch: 010/010 | Batch 050/938 | Cost: 0.0171\n",
            "Epoch: 010/010 | Batch 100/938 | Cost: 0.0133\n",
            "Epoch: 010/010 | Batch 150/938 | Cost: 0.0115\n",
            "Epoch: 010/010 | Batch 200/938 | Cost: 0.0386\n",
            "Epoch: 010/010 | Batch 250/938 | Cost: 0.0223\n",
            "Epoch: 010/010 | Batch 300/938 | Cost: 0.0116\n",
            "Epoch: 010/010 | Batch 350/938 | Cost: 0.0171\n",
            "Epoch: 010/010 | Batch 400/938 | Cost: 0.0016\n",
            "Epoch: 010/010 | Batch 450/938 | Cost: 0.0102\n",
            "Epoch: 010/010 | Batch 500/938 | Cost: 0.0840\n",
            "Epoch: 010/010 | Batch 550/938 | Cost: 0.0136\n",
            "Epoch: 010/010 | Batch 600/938 | Cost: 0.0291\n",
            "Epoch: 010/010 | Batch 650/938 | Cost: 0.0499\n",
            "Epoch: 010/010 | Batch 700/938 | Cost: 0.0169\n",
            "Epoch: 010/010 | Batch 750/938 | Cost: 0.0105\n",
            "Epoch: 010/010 | Batch 800/938 | Cost: 0.0099\n",
            "Epoch: 010/010 | Batch 850/938 | Cost: 0.0099\n",
            "Epoch: 010/010 | Batch 900/938 | Cost: 0.0388\n",
            "Epoch: 010/010 training accuracy: 99.00%\n",
            "Time elapsed: 2.32 min\n",
            "Total Training Time: 2.32 min\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rsBuY3gFLl2L",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "7a7c398f-8a22-4586-fe3d-ab4e0be0967c"
      },
      "source": [
        "# Accuracy of the trained model on the batches served by test_loader\n",
        "print('Test accuracy: %.2f%%' % compute_accuracy(model, test_loader))"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test accuracy: 97.50%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oIb8fEkvLl2Q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 69
        },
        "outputId": "d0230a68-83a3-4411-c202-da9045145d91"
      },
      "source": [
        "# Report the versions of the packages imported in this session\n",
        "%watermark --iversions"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "numpy 1.18.5\n",
            "torch 1.5.1+cu101\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}